# PANet 将网络的有效输出特征层和SPP结构的输出相融合

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from CSPDarknet53 import conv_block  # 网络模型和标准卷积块
from SPP import spp  # 导入spp加强特征提取模块

# 5次卷积操作提取特征减少参数量
def five_conv(x, filters):
    x = conv_block(x, filters, (1,1), strides=1)
    x = conv_block(x, filters*2, (3,3), strides=1)
    x = conv_block(x, filters, (1,1), strides=1)    
    x = conv_block(x, filters*2, (3,3), strides=1)
    x = conv_block(x, filters, (1,1), strides=1)
    return x


def panet(inputs):

    # 获得网络的三个有效输出特征层
    feat1, feat2, p5 = spp(inputs)

    #（1）
    # 对spp结构的输出进行卷积和上采样
    # [13,13,512]==>[13,13,256]==>[26,26,256]
    p5_upsample = conv_block(p5, filters=256, kernel_size=(1,1), strides=1)
    p5_upsample = layers.UpSampling2D(size=(2,2))(p5_upsample)

    # 对feat2特征层卷积后再与p5_upsample堆叠
    # [26,26,512]==>[26,26,256]==>[26,26,512]
    p4 = conv_block(feat2, filters=256, kernel_size=(1,1), strides=1)
    p4 = layers.concatenate([p4, p5_upsample])

    # 堆叠后进行5次卷积[26,26,512]==>[26,26,256]
    p4 = five_conv(p4, filters=256)

    #（2）
    # 对p4卷积上采样
    # [26,26,256]==>[26,26,512]==>[52,52,512]
    p4_upsample = conv_block(p4, filters=128, kernel_size=(1,1), strides=1)
    p4_upsample = layers.UpSampling2D(size=(2,2))(p4_upsample)

    # feat1层卷积后与p4_upsample堆叠
    # [52,52,256]==>[52,52,128]==>[52,52,256]
    p3 = conv_block(feat1, filters=128, kernel_size=(1,1), strides=1)
    p3 = layers.concatenate([p3, p4_upsample])

    # 堆叠后进行5次卷积[52,52,256]==>[52,52,128]
    p3 = five_conv(p3, filters=128)

    # 存放第一个特征层的输出
    p3_output = p3

    #（3）
    # p3卷积下采样和p4堆叠
    # [52,52,128]==>[26,26,256]==>[26,26,512]
    p3_downsample = conv_block(p3, filters=256, kernel_size=(3,3), strides=2)
    p4 = layers.concatenate([p3_downsample, p4])

    # 堆叠后的结果进行5次卷积[26,26,512]==>[26,26,256]
    p4 = five_conv(p4, filters=256)

    # 存放第二个有效特征层的输出
    p4_output = p4

    #（4）
    # p4卷积下采样和p5堆叠
    # [26,26,256]==>[13,13,512]==>[13,13,1024]
    p4_downsample = conv_block(p4, filters=512, kernel_size=(3,3), strides=2)
    p5 = layers.concatenate([p4_downsample, p5])

    # 堆叠后进行5次卷积[13,13,1024]==>[13,13,512]
    p5 = five_conv(p5, filters=512)

    # 存放第三个有效特征层的输出
    p5_output = p5

    # 返回输出层结果
    return p3_output, p4_output, p5_output

# 验证
if __name__ == '__main__':

    inputs = keras.Input(shape=[416,416,3])
    p3_output, p4_output, p5_output = panet(inputs)

    print('p3.shape:', p3_output.shape,  # (None, 52, 52, 128)
          'p4.shape:', p4_output.shape,  # (None, 26, 26, 256)
          'p5.shape:', p5_output.shape)  # (None, 13, 13, 512)